# -*- coding: utf-8 -*-
"""
02_generate_tables.py

Generates the results tables (Tables 2, 3, 5, 6, 7, 8 and 9). It reads
pre-computed model outputs and raw data files from the data/ directory and
writes formatted .tex and .csv files to the tables/ directory.

This script is intended to be run from the root of the replication package
by the master 'run.sh' script.

Date: July 17, 2025
Author: Shreyas Meher
"""

# ==============================================================================
# SECTION 1: IMPORTS
# ==============================================================================
import pandas as pd
import numpy as np
import json
import ast
import os
from sklearn.metrics import (
    classification_report, 
    precision_recall_fscore_support, 
    accuracy_score, 
    hamming_loss,
    roc_auc_score,
    f1_score
)
from sklearn.preprocessing import MultiLabelBinarizer, LabelEncoder, LabelBinarizer
from seqeval.metrics import classification_report as seqeval_report

# ==============================================================================
# SECTION 2: HELPER FUNCTIONS
# ==============================================================================

def convert_entities_to_bio(text_tokens, entities):
    """
    Converts a list of extracted entities (e.g., from an LLM) into a token-level
    BIO-tagged sequence for evaluation with seqeval.

    Each entity's ``entity_text`` is split on whitespace and matched against
    ``text_tokens``. Only its first occurrence whose tokens are all still
    untagged gets labelled. Entities are processed longest-first (by character
    length of ``entity_text``), so a longer span claims its tokens before any
    shorter entity nested inside it. Hyphens and spaces are removed from labels
    (e.g. 'Armed-Group' -> 'ArmedGroup').

    Malformed entries are skipped rather than raising. This covers non-dict
    items and a missing or null ``entity_text`` / ``entity_label``, both of
    which can occur in LLM-produced JSON.

    Args:
        text_tokens (list): A list of tokens from the original text.
        entities (list): A list of entity dictionaries, where each dict has
                         'entity_text' and 'entity_label'. Any non-list value
                         yields an all-'O' sequence.

    Returns:
        list: A list of BIO tags, one per token in text_tokens.
    """
    labels = ['O'] * len(text_tokens)
    if not isinstance(entities, list):
        return labels

    def _as_text(value):
        # JSON null -> '' (the entity is then skipped); other scalars are stringified.
        return '' if value is None else str(value)

    candidates = [
        (_as_text(entity.get('entity_text')), _as_text(entity.get('entity_label')))
        for entity in entities
        if isinstance(entity, dict)
    ]
    # Longest first so nested entities cannot pre-empt their enclosing span.
    # The sort is stable, so equally long entities keep their input order.
    candidates.sort(key=lambda pair: len(pair[0]), reverse=True)

    for entity_text, raw_label in candidates:
        entity_tokens = entity_text.split()
        entity_label = raw_label.replace('-', '').replace(' ', '')
        if not entity_tokens or not entity_label:
            continue
        span = len(entity_tokens)
        for i in range(len(text_tokens) - span + 1):
            if text_tokens[i:i + span] == entity_tokens and all(
                labels[j] == 'O' for j in range(i, i + span)
            ):
                labels[i] = f"B-{entity_label}"
                for j in range(i + 1, i + span):
                    labels[j] = f"I-{entity_label}"
                break  # Only the first free occurrence is labelled.
    return labels

def parse_prediction(pred_str, threshold=0.5):
    """
    Parses various prediction formats from the multi-label classification task,
    handling both JSON and pipe-separated strings.

    Formats are tried in this order:
      * A JSON object mapping label -> score. Single quotes are accepted by
        converting them to double quotes. The top-scoring label is the primary
        prediction, and labels scoring >= ``threshold`` form the multi-label
        list. Scores may be numbers or numeric strings, including scientific
        notation such as 1e-05. Boolean, negative, non-finite and non-numeric
        scores are ignored.
      * Otherwise, the first line of the string, with surrounding whitespace
        (including a trailing carriage return) stripped. Pipe-separated labels
        ("A | B") use the first label as primary; anything else is treated as
        a single label.

    Args:
        pred_str (str): The raw prediction string from the model.
        threshold (float): The probability threshold for including a label in multi-label results.

    Returns:
        tuple: The primary prediction (str) and the list of predicted labels.
               Falls back to "Unknown" / ["Unknown"] when nothing usable is
               found. The list also falls back to ["Unknown"] when no JSON
               score reaches ``threshold``.
    """
    primary_prediction, multi_label_list = "Unknown", ["Unknown"]
    # Non-strings (None, NaN from empty CSV cells) carry no prediction.
    if not isinstance(pred_str, str):
        return primary_prediction, multi_label_list

    def _as_score(value):
        # Numeric score or None. bool is excluded explicitly because float(True) == 1.0.
        if value is None or isinstance(value, bool):
            return None
        try:
            score = float(value)
        except (TypeError, ValueError, OverflowError):
            return None
        return score if np.isfinite(score) and score >= 0 else None

    try:
        pred_dict = json.loads(pred_str.replace("'", '"'))
    except (json.JSONDecodeError, TypeError, ValueError):
        # Not JSON: use the first line as pipe-separated labels or a single label.
        clean_str = pred_str.split('\n')[0].strip()
        if "|" in clean_str:
            labels = [label.strip() for label in clean_str.split("|") if label.strip()]
            if labels:
                primary_prediction, multi_label_list = labels[0], labels
        elif clean_str:
            primary_prediction, multi_label_list = clean_str, [clean_str]
    else:
        # Valid JSON that is not an object (e.g. a bare number) yields "Unknown".
        if isinstance(pred_dict, dict):
            scored = ((label, _as_score(value)) for label, value in pred_dict.items())
            valid_preds = sorted(
                ((label, score) for label, score in scored if score is not None),
                key=lambda pair: pair[1],
                reverse=True,
            )
            if valid_preds:
                primary_prediction = valid_preds[0][0]
                multi_label_list = [label for label, score in valid_preds if score >= threshold]
    return primary_prediction, (multi_label_list if multi_label_list else ["Unknown"])

def get_gold_standard_labels(row):
    """
    Gets ground truth for multi-label classification by checking all three
    attack type columns in the raw data.

    Missing (NaN/None), empty and "Unknown" entries are ignored, and duplicate
    labels are collapsed into one.

    Args:
        row (pd.Series): A row from the raw data DataFrame. Any mapping with
            the three attack-type keys also works.

    Returns:
        list: The distinct ground-truth labels in sorted order, or ["Unknown"]
              if none of the three columns holds a usable label.
    """
    labels = {
        row[col]
        for col in ('attacktype1_txt', 'attacktype2_txt', 'attacktype3_txt')
        if pd.notna(row[col]) and row[col] and row[col] != "Unknown"
    }
    # Sorting makes the output order independent of string hash randomisation
    # (PYTHONHASHSEED), so repeated runs produce identical lists.
    return sorted(labels, key=str) if labels else ["Unknown"]

def extract_main_prediction(json_str):
    """
    Safely extracts the attack type with the highest score from a
    dict-literal string for the mBERT comparison task.

    The string is parsed with ``ast.literal_eval``, so Python-style single
    quotes are accepted. Scores are compared numerically after ``float``
    conversion, and entries whose score is not a finite number are ignored.
    Ties go to the label that appears first in the dict.

    Args:
        json_str (str): A string potentially containing a dict literal that
            maps label -> score. Non-strings (e.g. NaN) yield 'Unknown'.

    Returns:
        str: The label with the highest score, or 'Unknown'.
    """
    try:
        prediction_dict = ast.literal_eval(json_str)
    except (ValueError, SyntaxError, TypeError):
        # TypeError covers literals such as dicts with unhashable keys.
        return "Unknown"
    if not isinstance(prediction_dict, dict):
        return "Unknown"

    # Comparing raw values with max() raises TypeError on mixed types (e.g. None vs float).
    scores = {}
    for label, score in prediction_dict.items():
        try:
            value = float(score)
        except (TypeError, ValueError, OverflowError):
            continue
        if np.isfinite(value):
            scores[label] = value
    return max(scores, key=scores.get) if scores else "Unknown"

def get_probabilities(series, classes, is_json=True):
    """
    Extracts probability scores for each class to enable AUC calculation.
    For single-label predictions, it creates a one-hot distribution.

    For JSON (dict-literal) entries, each label found in ``classes`` receives
    its ``float`` score. Scores that are not finite numbers (None, junk, NaN)
    are left at 0 individually, so one bad value does not discard the rest of
    the row. Unparseable or missing entries produce an all-zero row.

    Args:
        series (pd.Series): The column of model predictions.
        classes (list): The master list of all possible classes.
        is_json (bool): Flag indicating if the prediction is in JSON format.

    Returns:
        np.array: An array of shape (len(series), len(classes)) with the
                  probability scores.
    """
    class_map = {cls: i for i, cls in enumerate(classes)}
    y_scores = np.zeros((len(series), len(classes)))
    for row_idx, entry in enumerate(series):
        if not is_json:
            if pd.notna(entry) and entry in class_map:
                y_scores[row_idx, class_map[entry]] = 1.0
            continue
        # Only strings can be literal-evaluated; NaN/None mean "no prediction".
        if not isinstance(entry, str):
            continue
        try:
            pred_dict = ast.literal_eval(entry)
        except (ValueError, SyntaxError, TypeError):
            continue
        if not isinstance(pred_dict, dict):
            continue
        for label, score in pred_dict.items():
            if label not in class_map:
                continue
            try:
                value = float(score)
            except (TypeError, ValueError, OverflowError):
                continue
            # NaN/inf would make roc_auc_score fail downstream.
            if np.isfinite(value):
                y_scores[row_idx, class_map[label]] = value
    return y_scores


# ==============================================================================
# SECTION 3: TABLE GENERATION FUNCTIONS
# ==============================================================================

def _format_metric_cells(stats):
    """Render precision, recall, F1 and support of one report entry as LaTeX cells."""
    return (
        f"{stats.get('precision', 0):.4f} & {stats.get('recall', 0):.4f} & "
        f"{stats.get('f1-score', 0):.4f} & {int(stats.get('support', 0))}"
    )


def _render_two_row_metrics_table(reports, rows, caption, label):
    """Render a booktabs table with two metric rows per model; the model name spans both via \\multirow.

    Args:
        reports (dict): model display name -> sklearn/seqeval ``output_dict`` report.
            Models are rendered in the dict's iteration order.
        rows (list): two ``(report_key, display_name)`` pairs, e.g.
            ``[('Conflict', 'Conflict'), ('weighted avg', 'Weighted Avg')]``.
        caption (str): LaTeX caption text.
        label (str): LaTeX label.

    Returns:
        str: The complete ``table`` environment.
    """
    (first_key, first_name), (second_key, second_name) = rows
    lines = []
    for idx, (model_name, report) in enumerate(reports.items()):
        first, second = report.get(first_key, {}), report.get(second_key, {})
        lines.append(f"& {first_name} & {_format_metric_cells(first)} \\\\")
        lines.append(
            f"\\multirow{{-2}}{{*}}{{{model_name}}} & {second_name} & "
            f"{_format_metric_cells(second)} \\\\"
        )
        # Rule between models, but not after the last one.
        if idx < len(reports) - 1:
            lines.append("\\midrule")
    body = "\n".join(lines)
    return f"""\\begin{{table}}[tp]
\\centering
\\caption{{{caption}}}
\\label{{{label}}}
\\begin{{tabular}}{{llrrrr}}
\\toprule
Model & Class & Precision & Recall & F1-score & Support \\\\
\\midrule
{body}
\\bottomrule
\\end{{tabular}}
\\end{{table}}"""


def _binary_classification_reports(bc_input_csv, models):
    """Compute an sklearn classification report per model for the conflict / not-conflict task."""
    df_bc = pd.read_csv(bc_input_csv)
    # Gold: 'bin' == 1 is 'Conflict'; every other value is 'Not Conflict'.
    y_true_bc = df_bc['bin'].apply(lambda x: 'Conflict' if x == 1 else 'Not Conflict')
    reports = {}
    for name, col in models.items():
        # A string containing 'not' or 'non' (case-insensitive) reads as 'Not Conflict'.
        # Everything else, including NaN and non-string values, reads as 'Conflict'.
        # NOTE(review): numeric 0/1 predictions would all map to 'Conflict' — confirm
        # the prediction columns hold text labels.
        y_pred_bc = df_bc[col].apply(
            lambda x: 'Not Conflict'
            if isinstance(x, str) and ('not' in x.lower() or 'non' in x.lower())
            else 'Conflict'
        )
        reports[name] = classification_report(y_true_bc, y_pred_bc, output_dict=True, zero_division=0)
    return reports


def _ner_reports(ner_input_csv, models):
    """Score each model's extracted entities against the gold BIO tags with seqeval.

    A sentence whose gold and predicted tag sequences differ in length cannot be
    compared token-by-token. Such sentences are skipped, and the number skipped
    is printed.

    Raises:
        ValueError: if no sentence passes the length check for some model.
    """
    df_ner = pd.read_csv(ner_input_csv)
    y_true_ner = df_ner['gold_labels_bio'].apply(ast.literal_eval).tolist()
    # Missing text (NaN) gives no tokens instead of an AttributeError on .split().
    token_lists = [text.split() if isinstance(text, str) else [] for text in df_ner['text']]
    reports = {}
    for name, prefix in models.items():
        # A missing entity cell (NaN) means "no entities"; json.loads would raise TypeError on it.
        entity_lists = [
            json.loads(raw) if isinstance(raw, str) else None
            for raw in df_ner[f'{prefix}_entities']
        ]
        y_pred_ner = [
            convert_entities_to_bio(tokens, entities)
            for tokens, entities in zip(token_lists, entity_lists)
        ]
        pairs = [(yt, yp) for yt, yp in zip(y_true_ner, y_pred_ner) if len(yt) == len(yp)]
        skipped = len(y_true_ner) - len(pairs)
        if skipped:
            print(f"  Warning: {name}: skipped {skipped} of {len(y_true_ner)} NER sentences "
                  "whose gold and predicted tag sequences differ in length.")
        if not pairs:
            raise ValueError(
                f"No NER sentence for {name} has gold and predicted tag sequences of equal length."
            )
        reports[name] = seqeval_report(
            [yt for yt, _ in pairs], [yp for _, yp in pairs], output_dict=True, zero_division=0
        )
    return reports


def generate_bc_and_ner_tables(bc_input_csv, ner_input_csv, output_dir):
    """
    Generates tables related to the Binary Classification and NER tasks.
    - Table 2: BC Performance (.tex)
    - Table 3: NER Summary Performance (.tex)
    - Table 9: NER Detailed Appendix Table (.tex)

    Args:
        bc_input_csv (str): CSV with the gold 'bin' column and one prediction column per model.
        ner_input_csv (str): CSV with 'text', 'gold_labels_bio' and '<model>_entities' (JSON) columns.
        output_dir (str): Directory receiving table2.tex, table3.tex and table9.tex.

    Raises:
        ValueError: if, for some model, no NER sentence has gold and predicted
            tag sequences of equal length.
    """
    print("Generating Tables 2, 3, and 9...")
    # Display name -> column name (BC) or column prefix (NER). Shared by both tasks;
    # the dict order is the row order in every table.
    models = {'ConfliBERT': 'conflibert', 'Gemma 2 (9B)': 'gemma', 'Llama 3.1 (8B)': 'llama'}

    # --- Table 2: Binary Classification ---
    table2_tex = _render_two_row_metrics_table(
        _binary_classification_reports(bc_input_csv, models),
        [('Conflict', 'Conflict'), ('weighted avg', 'Weighted Avg')],
        'Performance Metrics for Binary Classifications.',
        'tab:binary_metrics',
    )
    with open(os.path.join(output_dir, 'table2.tex'), 'w', encoding='utf-8') as f:
        f.write(table2_tex)

    # --- NER metrics (shared by Tables 3 and 9) ---
    all_reports = _ner_reports(ner_input_csv, models)

    # --- Table 3: NER Summary ---
    table3_tex = _render_two_row_metrics_table(
        all_reports,
        [('micro avg', 'Micro Avg'), ('weighted avg', 'Weighted Avg')],
        'Overall Performance Metrics for Named Entity Recognition.',
        'tab:ner_metrics',
    )
    with open(os.path.join(output_dir, 'table3.tex'), 'w', encoding='utf-8') as f:
        f.write(table3_tex)

    # --- Table 9: NER Appendix ---
    generate_ner_appendix_table(all_reports, os.path.join(output_dir, 'table9.tex'))


def generate_ner_appendix_table(all_reports, output_file):
    """Generates the large appendix longtable (Table 9) with all NER details.

    For each model, in the iteration order of ``all_reports``, the table lists
    per-entity-class precision/recall/F1/support. These rows are followed by
    the micro, macro and weighted averages in italics.

    Args:
        all_reports (dict): model display name -> seqeval ``output_dict``
            classification report.
        output_file (str): Path of the .tex file to write.
    """
    summary_metrics = ['micro avg', 'macro avg', 'weighted avg']
    # seqeval reports mix entity-class keys with the summary keys above. The
    # summary keys are excluded here because they are rendered separately
    # below; otherwise they would appear twice.
    # NOTE(review): the len(cls) > 3 filter is kept as-is. It also hides entity
    # classes whose names have 3 characters or fewer (e.g. 'LOC') — confirm intended.
    all_classes = sorted({
        cls
        for report in all_reports.values()
        for cls in report
        if len(cls) > 3 and cls not in summary_metrics
    })

    parts = ["""\\begin{longtable}{llrrrr}
\\caption{Full Per-Class Performance Metrics for Named Entity Recognition Models.} \\label{tab:ner_appendix_full_long} \\\\
\\toprule
\\textbf{Model} & \\textbf{Entity Class} & \\textbf{Precision} & \\textbf{Recall} & \\textbf{F1-Score} & \\textbf{Support} \\\\
\\midrule
\\endfirsthead
\\caption[]{Full Per-Class Performance Metrics (continued)} \\\\
\\toprule
\\textbf{Model} & \\textbf{Entity Class} & \\textbf{Precision} & \\textbf{Recall} & \\textbf{F1-Score} & \\textbf{Support} \\\\
\\midrule
\\endhead
\\bottomrule
\\endfoot
\\bottomrule
\\endlastfoot
"""]

    model_names = list(all_reports)
    for i, name in enumerate(model_names):
        report = all_reports[name]
        parts.append(f"\\multicolumn{{6}}{{l}}{{\\textbf{{{name}}}}} \\\\\n\\midrule\n")
        for entity in all_classes:
            stats = report.get(entity)
            if stats:
                parts.append(
                    f"& {entity.replace('_', ' ')} & {stats['precision']:.4f} & "
                    f"{stats['recall']:.4f} & {stats['f1-score']:.4f} & {int(stats['support'])} \\\\\n"
                )
        parts.append("\\midrule\n")
        for metric_key in summary_metrics:
            metric_name = metric_key.replace('_', ' ').title()
            stats = report.get(metric_key, {})
            p, r, f1, s = (stats.get(k, 0) for k in ('precision', 'recall', 'f1-score', 'support'))
            parts.append(
                f"& \\textit{{{metric_name}}} & \\textit{{{p:.4f}}} & \\textit{{{r:.4f}}} & "
                f"\\textit{{{f1:.4f}}} & \\textit{{{int(s)}}} \\\\\n"
            )
        # Separate models with a rule; the last one is closed by \endlastfoot.
        if i < len(model_names) - 1:
            parts.append("\\midrule\n\n")

    parts.append("\\end{longtable}\n")
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("".join(parts))

def generate_multilabel_tables(input_csv, output_dir):
    """Generates Tables 5 & 6 as .csv files for verification.

    Table 5 scores each model's top-ranked prediction against
    ``attacktype1_txt``, reporting accuracy and macro precision/recall/F1.
    Table 6 scores each model's thresholded label set against all three
    attack-type columns, reporting subset accuracy, Hamming loss, partial
    match and predicted label cardinality.

    A model whose prediction column is missing from the CSV is skipped with a
    printed warning, so a missing row in the output is never silent.

    Args:
        input_csv (str): Raw multi-label data with the gold attack-type
            columns and the model prediction columns listed in ``model_map``.
        output_dir (str): Directory receiving table5.csv and table6.csv.
    """
    print("Generating Tables 5 & 6...")
    df = pd.read_csv(input_csv, low_memory=False)
    # Prediction column -> display name (dict order is the row order).
    model_map = {
        'eventdata-utd/gtd-2016_predictions': 'ConfliBERT',
        'hf.co/shreyasmeher/ConflLlama:Q4_K_M_predictions': 'ConflLlama-Q4KM',
        'hf.co/shreyasmeher/ConflLlama:Q8_0_predictions': 'ConflLlama-Q8',
        'gemma2:latest_predictions': 'Gemma 2',
        'llama3.1:latest_predictions': 'Llama 3.1',
        'qwen2.5:14b_predictions': 'Qwen 2.5',
    }
    y_true_single = df['attacktype1_txt'].fillna('Unknown')
    y_true_multi = df.apply(get_gold_standard_labels, axis=1)
    # The binarizer's label space is fixed by the gold labels; predicted labels
    # outside it are ignored by transform().
    mlb = MultiLabelBinarizer().fit(y_true_multi)
    y_true_binarized = mlb.transform(y_true_multi)

    table_5_data, table_6_data = [], []
    for col, name in model_map.items():
        if col not in df.columns:
            print(f"  Warning: column '{col}' not found in {input_csv}; "
                  f"{name} is omitted from Tables 5 & 6.")
            continue
        primary_preds, multi_preds = zip(*df[col].apply(parse_prediction))

        p, r, f1, _ = precision_recall_fscore_support(
            y_true_single, primary_preds, average='macro', zero_division=0
        )
        table_5_data.append({
            'Model': name,
            'Accuracy': accuracy_score(y_true_single, primary_preds),
            'Precision': p,
            'Recall': r,
            'F1': f1,
        })

        y_pred_binarized = mlb.transform(multi_preds)
        # Number of gold labels each prediction set recovers (row-wise overlap).
        overlap = np.sum(y_true_binarized * y_pred_binarized, axis=1)
        table_6_data.append({
            'Model': name,
            'Subset Accuracy (%)': accuracy_score(y_true_binarized, y_pred_binarized) * 100,
            'Hamming Loss': hamming_loss(y_true_binarized, y_pred_binarized),
            'Partial Match (%)': np.mean(overlap > 0) * 100,
            'Predicted Label Cardinality': np.mean(np.sum(y_pred_binarized, axis=1)),
        })

    pd.DataFrame(table_5_data).to_csv(os.path.join(output_dir, "table5.csv"), index=False, float_format='%.4f')
    pd.DataFrame(table_6_data).to_csv(os.path.join(output_dir, "table6.csv"), index=False, float_format='%.4f')

def generate_mbert_comparison_tables(input_csv, output_dir):
    """Generates Table 7 as a formatted .tex file and Table 8 as a .csv file.

    Compares ConfliBERT, whose predictions are dict-of-scores strings, with
    Confli-mBERT, whose predictions are single labels, on ``attacktype1_txt``.
    Table 7 reports accuracy, the unweighted mean of per-class F1 and the
    macro-averaged AUC. Table 8 lists per-class F1 together with each class's
    prevalence.

    A prediction that is not one of the gold classes is scored as the
    'Unknown' class when 'Unknown' is itself a gold class. Otherwise it is
    encoded as -1 (never correct). Passing an unseen 'Unknown' to
    ``LabelEncoder.transform`` would raise ValueError instead.

    Args:
        input_csv (str): Raw comparison data with gold labels and both models' prediction columns.
        output_dir (str): Directory receiving table7.tex and table8.csv.
    """
    print("Generating Tables 7 & 8...")
    df = pd.read_csv(input_csv, low_memory=False)
    model_cols = {'ConfliBERT': 'eventdata-utd/gtd-2016_predictions', 'Confli-mBERT': 'confli-mbert_predictions'}
    y_true = df['attacktype1_txt'].fillna('Unknown')
    le = LabelEncoder().fit(y_true)
    all_labels = le.classes_
    y_true_encoded = le.transform(y_true)
    y_true_binarized = LabelBinarizer().fit(all_labels).transform(y_true)
    # Same integer codes le.transform produces (index into the sorted classes_).
    class_index = {label: idx for idx, label in enumerate(all_labels)}
    fallback_code = class_index.get('Unknown', -1)

    results = {}
    for name, col in model_cols.items():
        # ConfliBERT stores per-class scores; Confli-mBERT stores a single label.
        is_json = 'mBERT' not in name
        y_pred_raw = df[col].apply(extract_main_prediction) if is_json else df[col].fillna('Unknown')
        y_pred_encoded = np.array([class_index.get(pred, fallback_code) for pred in y_pred_raw])
        y_scores = get_probabilities(df[col], classes=all_labels, is_json=is_json)
        f1_pc = f1_score(y_true_encoded, y_pred_encoded, average=None, labels=range(len(all_labels)), zero_division=0)
        results[name] = {
            'acc': accuracy_score(y_true_encoded, y_pred_encoded),
            'f1_per_class': f1_pc,
            'avg_f1': np.mean(f1_pc),
            'avg_auc': roc_auc_score(y_true_binarized, y_scores, average='macro', multi_class='ovr'),
        }

    cb, mb = results['ConfliBERT'], results['Confli-mBERT']
    diff_acc = cb['acc'] - mb['acc']
    diff_f1 = cb['avg_f1'] - mb['avg_f1']
    diff_auc = cb['avg_auc'] - mb['avg_auc']
    table7_tex = f"""\\begin{{table}}[h]
\\centering
\\caption{{Overall Performance Metrics}} \\label{{tab:overall_perf}}
\\begin{{tabular}}{{lccc}}
\\toprule
\\textbf{{Metric}} & \\textbf{{ConfliBERT}} & \\textbf{{Confli-mBERT}} & \\textbf{{Difference}} \\\\ \\midrule
Overall Accuracy & {cb['acc'] * 100:.2f}\\% & {mb['acc'] * 100:.2f}\\% & {diff_acc*100:+.2f}\\% \\\\
Average F1 (all types) & {cb['avg_f1']:.4f} & {mb['avg_f1']:.4f} & {diff_f1:+.4f} \\\\
Average AUC (all types) & {cb['avg_auc']:.4f} & {mb['avg_auc']:.4f} & {diff_auc:+.4f} \\\\
\\bottomrule
\\end{{tabular}}
\\end{{table}}"""
    with open(os.path.join(output_dir, 'table7.tex'), 'w', encoding='utf-8') as f:
        f.write(table7_tex)

    table_8_df = pd.DataFrame({
        'Attack Type': all_labels,
        'Prevalence': y_true.value_counts(normalize=True).reindex(all_labels, fill_value=0),
        'ConfliBERT': cb['f1_per_class'],
        'Confli-mBERT': mb['f1_per_class'],
    })
    table_8_df['Difference'] = table_8_df['ConfliBERT'] - table_8_df['Confli-mBERT']
    table_8_df.to_csv(os.path.join(output_dir, "table8.csv"), index=False, float_format='%.4f')

# ==============================================================================
# SECTION 4: MAIN EXECUTION BLOCK
# ==============================================================================

if __name__ == "__main__":
    # Generate every table in sequence. The paths are relative, so the script
    # must run from the replication-package root, as run.sh does.
    print("--- Starting All Table Generation ---")

    BC_MODEL_OUTPUTS = 'data/model_outputs/test_classification_results.csv'
    NER_MODEL_OUTPUTS = 'data/model_outputs/ner_inference_results.csv'
    MULTI_LABEL_RAW_DATA = 'data/raw/raw_gtd_multilabel_data.csv'
    MBERT_RAW_DATA = 'data/raw/raw_mbert_comparison_data.csv'
    TABLES_DIR = 'tables'

    # Ensure output directory exists
    os.makedirs(TABLES_DIR, exist_ok=True)

    try:
        generate_bc_and_ner_tables(BC_MODEL_OUTPUTS, NER_MODEL_OUTPUTS, TABLES_DIR)
        generate_multilabel_tables(MULTI_LABEL_RAW_DATA, TABLES_DIR)
        generate_mbert_comparison_tables(MBERT_RAW_DATA, TABLES_DIR)
        print("\n--- All tables generated successfully into the /tables/ directory. ---")
    except FileNotFoundError as e:
        print("\n❌ Error: A required data file was not found.")
        print(f"Details: {e}")
        print("Please ensure all raw data and model output files are in the correct /data/ subdirectories.")
        # Exit non-zero so run.sh (or any caller checking the exit status)
        # sees the failure instead of a successful run.
        raise SystemExit(1) from e